1b64c2
@@ -28,12 +28,14 @@
 import org.apache.hadoop.hive.ql.exec.vector.expressions.gen.FuncAbsLongToLong;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
 import org.apache.hadoop.io.IntWritable;
@@ -55,6 +57,7 @@
   private final DoubleWritable resultDouble = new DoubleWritable();
   private final LongWritable resultLong = new LongWritable();
   private final IntWritable resultInt = new IntWritable();
+  private final HiveDecimalWritable resultDecimal = new HiveDecimalWritable();
   private transient PrimitiveObjectInspector argumentOI;
   private transient Converter inputConverter;
 
@@ -94,9 +97,10 @@
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumen
       outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
       break;
     case DECIMAL:
+      outputOI = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
+          ((PrimitiveObjectInspector) arguments[0]).getTypeInfo());
       inputConverter = ObjectInspectorConverters.getConverter(arguments[0],
-          PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector);
-      outputOI = PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector;
+          outputOI);
       break;
     default:
       throw new UDFArgumentException(
@@ -129,11 +133,15 @@
public Object evaluate(DeferredObject[] arguments) throws HiveException {
       resultDouble.set(Math.abs(((DoubleWritable) valObject).get()));
       return resultDouble;
     case DECIMAL:
-      return PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector.set(
-          PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector
-              .create(HiveDecimal.ZERO),
-          PrimitiveObjectInspectorUtils.getHiveDecimal(valObject,
-              argumentOI).abs());
+      HiveDecimalObjectInspector decimalOI =
+          (HiveDecimalObjectInspector) argumentOI;
+      HiveDecimalWritable val = decimalOI.getPrimitiveWritableObject(valObject);
+
+      if (val != null) {
+        resultDecimal.set(val.getHiveDecimal().abs());
+        val = resultDecimal;
+      }
+      return val;
     default:
       throw new UDFArgumentException(
           "ABS only takes SHORT/BYTE/INT/LONG/DOUBLE/FLOAT/STRING/DECIMAL types, got " + inputType);
